{
 "cells": [
  {
   "cell_type": "raw",
   "id": "b50fb661-62ae-4706-a86e-1b8f50f08f0d",
   "metadata": {},
   "source": [
    "---\n",
    "title: visualization\n",
    "subtitle: Plots for registration and 3D visualization\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a3d1f8e-5b89-41b3-889d-2b054f63c125",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a22f019-d10f-4c94-aa9e-fc6d32f74f4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aefed743-3d9c-4316-a376-04e24b04a636",
   "metadata": {},
   "source": [
    "## Overlay over predicted edges on target images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a452921-86fb-49c1-a3c6-749fad621a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from io import BytesIO\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "from skimage.feature import canny\n",
    "from torchvision.utils import make_grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5b1780-fa0a-4d81-8103-40572c938431",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def _overlay_edges(target, pred, sigma, eps=1e-5):\n",
    "    pred = (pred - pred.min()) / (pred.max() - pred.min() + eps)\n",
    "    edges = canny(pred, sigma=sigma)\n",
    "    edges = np.ma.masked_where(~edges, edges)\n",
    "\n",
    "    buffer = BytesIO()\n",
    "    plt.subplot()\n",
    "    plt.imshow(target, cmap=\"gray\")\n",
    "    plt.imshow(edges, cmap=\"cool_r\", interpolation=\"none\", vmin=0.0, vmax=1.0)\n",
    "    plt.axis(\"off\")\n",
    "    plt.savefig(buffer, format=\"png\", bbox_inches=\"tight\", pad_inches=0, dpi=300)\n",
    "    arr = np.array(Image.open(buffer))\n",
    "    plt.close()\n",
    "    return arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dc560ca-c5f3-4569-990f-efa40d7c8481",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def overlay_edges(target, pred, sigma=1.5):\n",
    "    \"\"\"Generate edge overlays for a batch of targets and predictions.\"\"\"\n",
    "    edges = []\n",
    "    for i, p in zip(target, pred):\n",
    "        edge = _overlay_edges(i[0].cpu().numpy(), p[0].cpu().numpy(), sigma)\n",
    "        edges.append(edge)\n",
    "    edges = torch.from_numpy(np.stack(edges)).permute(0, -1, 1, 2)\n",
    "    edges = make_grid(edges).permute(1, 2, 0)\n",
    "    return edges"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab90aa75-5869-49d7-aaa2-5f9041295840",
   "metadata": {},
   "source": [
    "## Using PyVista to visualize 3D geometry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d197c9-dfbd-4029-bd23-05b4fda02a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import pyvista\n",
    "from torch.nn.functional import pad\n",
    "\n",
    "from diffpose.calibration import RigidTransform, perspective_projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5897c4-605e-44eb-aa78-097a273211bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "def fiducials_3d_to_projected_fiducials_3d(specimen, pose):\n",
    "    # Extrinsic camera matrix\n",
    "    extrinsic = (\n",
    "        specimen.lps2volume.inverse()\n",
    "        .compose(pose.inverse())\n",
    "        .compose(specimen.translate)\n",
    "        .compose(specimen.flip_xz)\n",
    "    )\n",
    "\n",
    "    # Intrinsic projection -> in 3D\n",
    "    x = perspective_projection(extrinsic, specimen.intrinsic, specimen.fiducials)\n",
    "    x = -specimen.focal_len * torch.einsum(\n",
    "        \"ij, bnj -> bni\",\n",
    "        specimen.intrinsic.inverse(),\n",
    "        pad(x, (0, 1), value=1),  # Convert to homogenous coordinates\n",
    "    )\n",
    "\n",
    "    # Some command-z\n",
    "    extrinsic = (\n",
    "        specimen.flip_xz.inverse().compose(specimen.translate.inverse()).compose(pose)\n",
    "    )\n",
    "    return extrinsic.transform_points(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902878a6-16b1-4ea0-b4b7-3693c5e7607d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def fiducials_to_mesh(\n",
    "    specimen,\n",
    "    rotation=None,\n",
    "    translation=None,\n",
    "    parameterization=None,\n",
    "    convention=None,\n",
    "    detector=None,\n",
    "):\n",
    "    \"\"\"\n",
    "    Use camera matrices to get 2D projections of 3D fiducials for a given pose.\n",
    "    If the detector is passed, 2D projections will be filtered for those that lie\n",
    "    on the detector plane.\n",
    "    \"\"\"\n",
    "    # Location of fiducials in 3D\n",
    "    fiducials_3d = specimen.lps2volume.inverse().transform_points(specimen.fiducials)\n",
    "    fiducials_3d = pyvista.PolyData(fiducials_3d.squeeze().numpy())\n",
    "    if rotation is None and translation is None and parameterization is None:\n",
    "        return fiducials_3d\n",
    "\n",
    "    # Embedding of fiducials in 2D\n",
    "    pose = RigidTransform(rotation, translation, parameterization, convention, device=\"cpu\")\n",
    "    fiducials_2d = fiducials_3d_to_projected_fiducials_3d(specimen, pose)\n",
    "    fiducials_2d = fiducials_2d.squeeze().numpy()\n",
    "\n",
    "    # Optionally, only render 2D fiducials that lie on the detector plane\n",
    "    if detector is not None:\n",
    "        corners = detector.points.reshape(\n",
    "            detector[\"height\"][0], detector[\"width\"][0], 3\n",
    "        )[\n",
    "            [0, 0, -1, -1],\n",
    "            [0, -1, 0, -1],\n",
    "        ]\n",
    "        exclude = np.logical_or(\n",
    "            fiducials_2d < corners.min(0),\n",
    "            fiducials_2d > corners.max(0),\n",
    "        ).any(1)\n",
    "        fiducials_2d = fiducials_2d[~exclude]\n",
    "\n",
    "    fiducials_2d = pyvista.PolyData(fiducials_2d)\n",
    "    return fiducials_3d, fiducials_2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87699278-de70-44c7-887e-7b63ba3f1981",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def lines_to_mesh(camera, fiducials_2d):\n",
    "    \"\"\"Draw lines from the camera to the 2D fiducials.\"\"\"\n",
    "    lines = []\n",
    "    for pt in fiducials_2d.points:\n",
    "        line = pyvista.Line(pt, camera.center)\n",
    "        lines.append(line)\n",
    "    return lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba77dd5d-cd83-4a1e-8520-5d27d9b972fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev\n",
    "\n",
    "nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa4fbd2b-75ce-48d0-90ff-56ddb421a547",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
